library(synthdid); library(ebal); library(tidyverse)

#' SDID with entropy-balancing unit weights
#'
#' Control-unit weights come from ebal::ebalance(), balancing the controls on
#' the treated units' pre-treatment outcomes (plus any covariates). Time
#' weights come from synthdid's internal Frank-Wolfe solver, fitted on the
#' controls with the post-treatment periods averaged into one target column.
#'
#' @param d data frame whose first column is an index (unused), second column
#'   is the treatment indicator (0 = control, 1 = treated), followed by
#'   `n_cov` covariate columns and then the outcome columns; the first `t_c`
#'   outcome columns are pre-treatment, the rest post-treatment. Rows may be
#'   in any order.
#' @param t_c number of pre-treatment periods
#' @param n_cov number of covariates (0 by default)
#' @param intercept Boolean: include unit fixed effects?
#' @return list with `est` (the SDID estimate), `omega` (ebalance weights for
#'   the control rows, in their order in `d`), `lambda` (pre-period time
#'   weights), and `trend_c` / `trend_t` (1-row matrices holding the weighted
#'   control and the average treated outcome path over all periods)
sdid_eb_est = function(d, t_c, n_cov = 0, intercept = TRUE) {
  w = as.numeric(as.matrix(d[, 2]))
  # logical masks rather than 1:n_c ranges, so row order does not matter
  is_ctrl = w == 0
  n_t = sum(!is_ctrl)
  y = as.matrix(d[, (3 + n_cov):ncol(d)])
  stopifnot(
    "t_c must leave at least one post-treatment period" =
      t_c >= 1 && t_c < ncol(y),
    "need at least one control and one treated unit" =
      any(is_ctrl) && n_t > 0
  )
  pre = seq_len(t_c)
  y_ctrl = y[is_ctrl, , drop = FALSE]
  y_trt = y[!is_ctrl, , drop = FALSE]

  # balancing matrix: (optionally unit-demeaned) pre-period outcomes + covariates
  yy = y[, pre, drop = FALSE]
  if (intercept) {
    yy = yy - rowMeans(yy)
    # demeaned columns sum to zero within each row, so one is redundant
    yy = yy[, -1, drop = FALSE]
  }
  X = if (n_cov > 0) as.matrix(d[, 3:(2 + n_cov), drop = FALSE]) else NULL
  omega = ebalance(w, cbind(yy, X))$w
  # ebalance matches sum(omega) to n_t only up to its constraint tolerance;
  # rescale so the control average is a proper weighted mean
  omega_bar = omega / sum(omega)

  # time weights on the collapsed form: post periods averaged into a single
  # target column (identical to y_ctrl when there is one post period)
  y_collapsed = cbind(
    y_ctrl[, pre, drop = FALSE],
    rowMeans(y_ctrl[, -pre, drop = FALSE])
  )
  lambda = synthdid:::sc.weight.fw(y_collapsed,
    zeta = 1e-6, intercept = intercept, lambda = NULL,
    min.decrease = 1e-5, max.iter = 100
  )$lambda

  trend_c = t(omega_bar) %*% y_ctrl
  trend_t = t(rep(1 / n_t, n_t)) %*% y_trt
  mean_c0 = trend_c[, pre, drop = FALSE] %*% lambda
  mean_t0 = trend_t[, pre, drop = FALSE] %*% lambda
  mean_c1 = mean(trend_c[, -pre])
  mean_t1 = mean(trend_t[, -pre])
  did = as.numeric((mean_t1 - mean_c1) - (mean_t0 - mean_c0))
  list(
    est = did, omega = omega, lambda = lambda,
    trend_c = trend_c, trend_t = trend_t
  )
}

#' Entropy-balanced SDID with bootstrap replicates
#'
#' @param d data frame in the layout expected by sdid_eb_est(): index, 0/1
#'   treatment indicator, `n_cov` covariates, then outcome columns
#' @param t_c number of pre-treatment periods
#' @param n_cov number of covariates (0 by default)
#' @param intercept Boolean: include unit fixed effects?
#' @param boot_reps number of bootstrap replicates (units resampled with
#'   replacement)
#' @return object of class "sdid_eb": the sdid_eb_est() output for the full
#'   sample plus `est_boot`, the numeric vector of bootstrap estimates
sdid_eb = function(d, t_c, n_cov = 0, intercept = TRUE, boot_reps) {
  treat = d[[2]]
  stopifnot(
    "need at least one control and one treated unit" =
      any(treat == 0) && any(treat != 0)
  )
  n = nrow(d)
  # sort rows so controls (treatment == 0) come first; the full-sample fit
  # and every replicate then see the same row layout
  controls_first = function(df) df[order(df[[2]]), , drop = FALSE]

  est_boot = numeric(boot_reps)
  for (i in seq_len(boot_reps)) {
    # redraw until both groups are present; otherwise the balancing weights
    # and treated averages are undefined
    repeat {
      idx = sample.int(n, size = n, replace = TRUE)
      if (any(treat[idx] == 0) && any(treat[idx] != 0)) break
    }
    d_boot = controls_first(d[idx, , drop = FALSE])
    est_boot[i] = sdid_eb_est(d_boot, t_c, n_cov, intercept)$est
  }

  ret = append(
    sdid_eb_est(controls_first(d), t_c, n_cov, intercept),
    list(est_boot = est_boot)
  )
  class(ret) = "sdid_eb"
  ret
}

#' Summarize an sdid_eb fit
#' @param x object returned by sdid_eb()
#' @param ... ignored; accepted for compatibility with the summary() generic
#' @return named numeric vector: `est` (full-sample estimate) and `se`
#'   (standard deviation of the bootstrap estimates)
summary.sdid_eb = \(x, ...) c(est = x$est, se = sd(x$est_boot))
